#!/usr/bin/env python3

#
# Imports
#

import sys
import csv
import argparse
import datetime
import pickle
import os.path

from collections import defaultdict

import numpy as np
import scipy.optimize
import scipy.misc

#
# Consistency check functions
#

def compute_ercs_array(winner, violations, normalize=True):
    """Build the ERCs (contender minus winner violation differences) for one winner.

    Every candidate profile in `violations` has the winner's profile subtracted
    from it. Rows that come out entirely zero carry no information: they are the
    winner itself or a candidate with an identical profile. Those rows are removed.
    Arguments:
        winner: string, the winning candidate (a key of `violations`).
        violations: dictionary, contender to violation profiles mapping.
        normalize: boolean, whether to return an array of elements in {0,1,-1}
            (the sign of each difference). Defaults to True.
    Returns an np.array with one row per candidate whose profile differs from the
    winner's and one column per constraint.
    """
    reference = violations[winner]
    stacked = np.asarray(list(violations.values()))
    diffs = stacked - reference[None, :]
    informative = (diffs != 0).any(1)
    diffs = diffs[informative, :]
    if not normalize:
        return diffs
    # Sign encoding: 1 for W (contender has more violations), -1 for L, 0 otherwise.
    wins = (diffs > 0).astype('int')
    losses = (diffs < 0).astype('int')
    return wins - losses


def consistent_ot_array(ercs_array):
    """Check whether a set of ercs in np.array format is compatible with a ranking.

    Each pass considers the ERCs that are still unexplained:
    - If no ERC contains an L, a ranking exists.
    - If no constraint carries a W while carrying no L, no ranking exists.
    - Otherwise those constraints can go on top, so every ERC they explain
      (one with a W on such a constraint) is discarded, and the loop repeats.
    Arguments:
        ercs_array: np.array of shape N x |CON|, contains -1 (L), 1 (W), 0
    Returns a boolean.
    """
    pending = ercs_array
    while pending.shape[0] > 0:
        has_l = (pending < 0).any(0)
        has_w = (pending > 0).any(0)
        if not has_l.any():
            return True
        if not (has_w & ~has_l).any():
            return False
        # Keep only the ERCs whose every W lies on a constraint that still has an L.
        # NOTE(review): the original author questioned whether Riggle's book states this step correctly.
        explained = ((pending > 0) & ~has_l[None, :]).any(1)
        pending = pending[~explained]
    return True


def consistent_hg_array(ercs_array):
    """Check whether a set of ercs in np.array format is compatible with a weighting.

    The check is solved as a linear feasibility problem: it looks for non-negative
    weights w such that every ERC row r satisfies r . w >= 1. Because any strictly
    positive margin can be rescaled to 1, this is equivalent to r . w > 0.
    Arguments:
        ercs_array: np.array of shape N x |CON|, contains violation differences.
    Returns a boolean.
    """
    num_rows, num_cols = ercs_array.shape
    if not num_rows:
        return True
    lp = scipy.optimize.linprog(
        np.zeros(num_cols),
        A_ub=-ercs_array,
        b_ub=-np.ones(num_rows),
        bounds=(0, None),
    )
    return lp.success
 

def consistent_ot(mapping, gen, return_ercs=False):
    """Check whether a language is consistent with at least one ranking.

    The sign-normalized ERCs of every (underlying form, winner) pair are stacked
    in the iteration order of `mapping` and tested with consistent_ot_array.
    Arguments:
        mapping: dictionary, mappings from underlying forms to winning candidates.
        gen: dictionary mapping UFs to a dictionary mapping SFs to violation profiles.
        return_ercs: boolean, whether to also return the stacked array of ercs.
    Returns a boolean, or a (boolean, array) pair.
    Note: np.vstack raises ValueError when `mapping` is empty.
    """
    stacked_ercs = np.vstack([compute_ercs_array(winner, gen[form])
                              for form, winner in mapping.items()])
    is_consistent = consistent_ot_array(stacked_ercs)
    if return_ercs:
        return (is_consistent, stacked_ercs)
    return is_consistent


def consistent_hg(mapping, gen, return_ercs=False):
    """Check whether a language is consistent with at least one weighting.

    The raw (un-normalized) violation-difference ERCs of every (underlying form,
    winner) pair are stacked in the iteration order of `mapping`. The stack is
    then tested with consistent_hg_array.
    Arguments:
        mapping: dictionary, mappings from underlying forms to surface forms.
        gen: dictionary mapping UFs to a dictionary mapping SFs to violation profiles.
        return_ercs: boolean, whether to also return the stacked array of ercs.
    Returns a boolean, or a (boolean, array) pair.
    Note: np.vstack raises ValueError when `mapping` is empty.
    """
    stacked_ercs = np.vstack([compute_ercs_array(winner, gen[form], normalize=False)
                              for form, winner in mapping.items()])
    is_consistent = consistent_hg_array(stacked_ercs)
    if return_ercs:
        return (is_consistent, stacked_ercs)
    return is_consistent

#
# Other useful functions.
#

def generate_permutations(n):
    """Generate all permutations of range(n) as a NumPy array.

    The construction is recursive. Each permutation of range(n-1) is extended by
    inserting the new element n-1 at position 0, then 1, ..., then n-1. The row
    order is therefore deterministic. Callers that break ties by taking the first
    maximizing row depend on this order.
    Arguments:
        n: integer, set size (must be non-negative).
    Returns an np.array of shape n! x n containing all permutations of range(n).
    For n == 0 this is the single empty permutation, i.e. shape 1 x 0.
    Raises ValueError if n is negative.
    """
    if n < 0:
        raise ValueError('n must be non-negative, got {}'.format(n))
    if n == 0:
        # Base case: exactly one (empty) permutation. The n == 1 case follows from it.
        return np.zeros((1, 0), dtype='int')
    perms = generate_permutations(n - 1)
    end_result = []
    for i in range(n):
        new_column = np.full((perms.shape[0], 1), n - 1, dtype='int')
        new_perms = np.hstack((perms[:, :i], new_column, perms[:, i:]))
        end_result.append(new_perms)
    return np.vstack(end_result)

def rankings_to_weights(rankings, Delta, delta):
    """Compute the HG weights associated to the OT ranking.

    Entry j of a row receives the weight ((Delta + delta) / delta) ** (n - rankings[k, j]),
    where n is the number of columns. An entry equal to 0 therefore gets the largest
    weight. All rows are divided by the sum of the first row. For permutation input
    every row has that same sum.
    Arguments:
        rankings: np.array of shape N x n containing permutations of indices in range(n).
        Delta: float, parameter Delta from OTvsHG.pdf.
        delta: float, parameter delta from OTvsHG.pdf.
    Returns an np.array of shape N x n containing normalized weights.
    """
    base = (Delta + delta) / delta
    exponents = rankings.shape[1] - rankings
    weights = base ** exponents
    weights /= weights[0, :].sum()
    return weights

def effective_dimension(weights, n, delta):
    """Compute the effective dimension of an array of weights.

    The weights are sorted in decreasing order. For each k = 0..T (T = len(weights))
    the cost k + (sum of all but the k largest weights)**2 * 2 * log(n) / delta**2
    is evaluated, and the minimum cost is returned.
    Arguments:
        weights: np.array, one-dimensional.
        n: integer, argument of the logarithm (the caller passes the sample size).
        delta: float, margin; it is squared in a denominator, so a zero value yields inf/nan.
    Returns a float.
    """
    descending = weights[np.argsort(-weights)]
    # tail_sums[k] == descending[k:].sum(); the appended 0 covers k == T (all weights kept).
    tail_sums = np.append(np.cumsum(descending[::-1])[::-1], 0)
    costs = np.arange(len(weights) + 1) + tail_sums ** 2 * 2 * np.log(n) / delta ** 2
    return np.amin(costs)

def optimal_weights(ercs, normalize=False):
    """Compute the optimal weights (wrt margin) for feasible ERCs.

    The solver works on the variables x = [gamma, w_1, ..., w_|CON|] and solves
        maximize gamma  subject to  ercs @ w >= gamma,  sum(w) <= 1,  gamma >= 0,  w >= 0.
    The resulting w is rescaled to sum to 1.
    Arguments:
        ercs: np.ndarray (or array-like), N x |CON|. Holds violation differences
            (contender minus winner, as built by compute_ercs_array), so a positive
            dot product with w means the winner wins.
        normalize: bool, whether to replace all values by ±1 (their sign) before solving.
    Returns an np.ndarray of size |CON| with non-negative entries summing to 1.
    Raises:
        RuntimeError: if the linear program cannot be solved.
        ValueError: if the optimal weights are all zero (the ERCs admit no positive
            margin), so they cannot be normalized.
    """
    ercs = np.asarray(ercs)
    if normalize:
        ercs = (ercs > 0).astype('int') - (ercs < 0).astype('int')
    N, d = ercs.shape
    # linprog minimizes, so minimize -gamma.
    obj = np.r_[-1, np.zeros(d)]
    # gamma - ercs @ w <= 0: every ERC has margin at least gamma.
    ub_matrix = np.hstack((np.ones((N, 1)), -ercs))
    # sum(w) <= 1 fixes the scale of w; without it gamma would be unbounded.
    ub_matrix = np.vstack((ub_matrix, np.r_[0, np.ones(d)]))
    ub_vector = np.r_[np.zeros(N), 1]
    # linprog's default bounds (0, None) keep gamma and every weight non-negative.
    opt_result = scipy.optimize.linprog(c=obj, A_ub=ub_matrix, b_ub=ub_vector)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not opt_result.success:
        raise RuntimeError('margin-maximization LP failed: {}'.format(opt_result.message))
    w = opt_result.x[1:]
    total = np.sum(w)
    if total <= 0:
        raise ValueError('optimal weights are all zero; the ERCs admit no positive margin')
    return w / total
    
if __name__ == '__main__':
    
    #
    # Command-line arguments.
    #
    
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', type=int, metavar='N', help='maximum number of data points (at least 3); if omitted, only the typology is computed')
    parser.add_argument('-d', '--distribution', metavar='DIST', choices=['uniform', 'zipfian'], default='uniform', help='sampling distribution')
    parser.add_argument('-a', metavar='A', type=float, default=1, help='zipfian distribution parameter')
    parser.add_argument('-r', metavar='R', type=int, default=1, help='number of times simulation is repeated')
    parser.add_argument('-o', '--output', metavar='output_file.npz', default=None, help='name of file where to save the results')
    parser.add_argument('-t', '--typology', metavar='typology.pkl', help='name of file where to save and/or load the typology')
    parser.add_argument('--trick', action='store_true', help='use the margin-optimal weights of the OT grammars as OT weights instead of weights derived from all rankings')
    # parser.add_argument('-l', '--losers', action='store_true', help='include pairs of losers')
    parser.add_argument('csv_filename', metavar='input_file.csv', help='name of CSV file')
    args = parser.parse_args()
    # The simulation loops over sample sizes 3..N, so smaller values are meaningless.
    if args.n is not None and args.n < 3:
        parser.error('argument -n: N must be at least 3')
    
    #
    # Source data processing.
    # Data is taken to be a csv file with a header and the following row structure:
    #     underlying form, surface form, violations #1, violations #2, etc.
    #

    generating_function = defaultdict(dict)
    
    # Distinct underlying forms in order of first appearance. The CSV has one row per
    # (UF, candidate) pair, so a UF is recorded only the first time it is seen. Its
    # position in this list is its rank for the zipfian distribution.
    underlying_forms_list = []

    with open(args.csv_filename, 'r', newline='') as csv_file:
        reader = csv.reader(csv_file)
        constraint_list = next(reader)[2:]
        for row in reader: # one (underlying form, candidate) pair per row
            if not row: # blank line, e.g. a trailing newline at the end of the file
                continue
            violations = np.asarray([float(x) for x in row[2:]])
            uf = str(row[0])
            sf = str(row[1])
            if uf not in generating_function:
                underlying_forms_list.append(uf)
            generating_function[uf][sf] = violations

    num_constraints = len(constraint_list)

    num_underlying_forms = len(generating_function)


    #
    # Typology computation.
    # We compute two lists, ot_typology_list and hg_typology_list.
    # They are list of pairs (ranking, mapping).
    # mapping is a dictionary: underlying_form -> winner.
    # ranking is an np.array of ercs.
    #
    
    if args.typology is not None and os.path.isfile(args.typology):
        # NOTE(review): pickle.load can execute arbitrary code; only load typology
        # files produced by this script.
        with open(args.typology, 'rb') as pkl:
            ot_typology_list, hg_typology_list, ot_mask = pickle.load(pkl)
            
        print('Loaded {} HG grammars of which {} are OT.\n'.format(len(hg_typology_list), len(ot_typology_list)))
    else:
        # Grow the set of consistent partial languages one underlying form at a time,
        # keeping only extensions that remain consistent.
        ot_typology_list = [(np.empty((0, num_constraints)), dict())]

        for i, (form, candidates) in enumerate(generating_function.items()):
            print('{}: processing form \'{}\' ({:03d} / {:03d} -- OT).'.format(datetime.datetime.now(), form, i+1, num_underlying_forms))
            new_typology = []
            for ercs_array, mapping in ot_typology_list:
                for candidate in candidates.keys():
                    new_mapping = mapping.copy()
                    new_mapping[form] = candidate
                    consistent, new_ercs = consistent_ot(new_mapping, generating_function, True)
                    if consistent:
                        new_typology.append((new_ercs, new_mapping))
                
            ot_typology_list = new_typology

        hg_typology_list = [(np.empty((0, num_constraints)), dict())]

        for i, (form, candidates) in enumerate(generating_function.items()):
            print('{}: processing form \'{}\' ({:03d} / {:03d} -- HG).'.format(datetime.datetime.now(), form, i+1, num_underlying_forms))
            new_typology = []
            for ercs_array, mapping in hg_typology_list:
                for candidate in candidates.keys():
                    new_mapping = mapping.copy()
                    new_mapping[form] = candidate
                    consistent, new_ercs = consistent_hg(new_mapping, generating_function, True)
                    if consistent:
                        new_typology.append((new_ercs, new_mapping))
                
            hg_typology_list = new_typology

        print('Found {} HG grammars of which {} are OT.\n'.format(len(hg_typology_list), len(ot_typology_list)))
        
        # An HG grammar is flagged as OT if some OT grammar has exactly the same mapping.
        # Mappings are dicts of strings, so their item sets can be hashed for O(1) lookup
        # (instead of comparing every HG mapping against every OT mapping).
        ot_mappings = {frozenset(ot_mapping.items()) for _, ot_mapping in ot_typology_list}
        ot_mask = np.array([frozenset(mapping.items()) in ot_mappings
                            for _, mapping in hg_typology_list], dtype='bool')
                
        if args.typology is not None:
            with open(args.typology, 'wb') as pkl:
                pickle.dump((ot_typology_list, hg_typology_list, ot_mask), pkl)
    
    if args.n is None:
        # Without a maximum sample size there is nothing to simulate.
        print('No maximum number of data points (-n) given: skipping the learning simulation.')
        sys.exit(0)
    
    #
    # Learning simulation.
    #
    
    np.random.seed(1234)
    
    Delta = np.amax([np.amax([np.amax(v) for v in sf2v.values()]) for sf2v in generating_function.values()])
    delta = 1.0
    
    all_opt_weights = np.vstack([optimal_weights(ercs) for ercs, mapping in hg_typology_list])
    active_constraints = num_constraints - np.sum(all_opt_weights == 0, 1)
    if args.trick:
        ot_weights = all_opt_weights[ot_mask, :]
    else:
        ot_rankings = generate_permutations(num_constraints)
        ot_weights = rankings_to_weights(ot_rankings, Delta, delta)
    
    # Probability of each underlying form, indexed like underlying_forms_list.
    # It does not depend on the grammar, so it is computed once.
    if args.distribution == 'uniform':
        uf_p = np.ones(num_underlying_forms)
    elif args.distribution == 'zipfian':
        # Float dtype: NumPy refuses integer arrays raised to negative integer powers, and
        # args.a is the int 1 when -a is omitted (argparse does not convert non-string defaults).
        uf_p = np.arange(1, num_underlying_forms+1, dtype=float) ** (-args.a)
    uf_p /= np.sum(uf_p)
    
    margins_ot = np.zeros((len(hg_typology_list), args.n-2, args.r))
    margins_hg = np.zeros_like(margins_ot)
    errors_ot = np.zeros_like(margins_ot)
    errors_hg = np.zeros_like(margins_ot)
    weights_ot = np.zeros((len(hg_typology_list), args.n-2, args.r, num_constraints))
    weights_hg = np.zeros_like(weights_ot)
    edim_ot = np.zeros_like(margins_ot)
    edim_hg = np.zeros_like(margins_ot)
    
    for grammar_k, (ranking, mapping) in enumerate(hg_typology_list):
        print('{}: processing grammar #{} ({}).'.format(datetime.datetime.now(), grammar_k+1, 'OT' if ot_mask[grammar_k] else 'HG\\OT'))
        
        #
        # Generate all possible data points with their probability.
        # Column 0 holds P(uf, contender); the remaining columns hold the ERC.
        #
        
        all_ercs_with_p = []
        for i, uf in enumerate(underlying_forms_list):
            winner = mapping[uf]
            # NOTE(review): a UF whose only candidate is the winner yields an empty
            # contender array, which makes the array operations below fail; every UF
            # is assumed to have at least two candidates.
            contender_profiles = np.asarray([v for sf, v in generating_function[uf].items() if sf != winner])
            contender_p = np.ones(contender_profiles.shape[0])
            contender_p /= len(contender_p)
            contender_p *= uf_p[i] # P(uf, sf) = P(uf) * (1 / (|Gen(uf)| - 1))
            contender_ercs = contender_profiles - generating_function[uf][winner][None, :]
            all_ercs_with_p.append(np.hstack((contender_p[:, None], contender_ercs)))
        all_ercs_with_p = np.vstack(all_ercs_with_p)
        num_ercs = all_ercs_with_p.shape[0]
        
        #
        # Learning simulation proper.
        #
        
        for n in range(3, args.n+1):
            for m in range(args.r):
                
                #
                # Pick training data.
                #
                
                data_points_inds = np.random.choice(num_ercs, n, p=all_ercs_with_p[:, 0])
                data_points_ercs = all_ercs_with_p[:, 1:][data_points_inds, :]
                
                #
                # Compute optimal OT ranking and generalization error.
                # np.argmax returns the first maximizer, so ties go to the earliest weight row.
                #
                
                empirical_margins_ot = np.min(np.dot(ot_weights, data_points_ercs.T), 1)
                opt_k = np.argmax(empirical_margins_ot)
                opt_empirical_margin_ot = empirical_margins_ot[opt_k]
                
                opt_ot_labels = np.dot(all_ercs_with_p[:, 1:], ot_weights[opt_k, :])
                generalization_error_ot = np.dot(all_ercs_with_p[:, 0], opt_ot_labels < 0)
                
                #
                # Compute optimal HG weights and generalization error.
                #
                
                opt_hg_weights = optimal_weights(data_points_ercs)
                opt_empirical_margin_hg = np.min(np.dot(data_points_ercs, opt_hg_weights))
                opt_hg_labels = np.dot(all_ercs_with_p[:, 1:], opt_hg_weights)
                generalization_error_hg = np.dot(all_ercs_with_p[:, 0], opt_hg_labels < 0)
                
                margins_ot[grammar_k, n-3, m] = opt_empirical_margin_ot
                margins_hg[grammar_k, n-3, m] = opt_empirical_margin_hg
                errors_ot[grammar_k, n-3, m] = generalization_error_ot
                errors_hg[grammar_k, n-3, m] = generalization_error_hg
                weights_ot[grammar_k, n-3, m, :] = ot_weights[opt_k, :]
                weights_hg[grammar_k, n-3, m, :] = opt_hg_weights
                edim_ot[grammar_k, n-3, m] = effective_dimension(weights_ot[grammar_k, n-3, m, :], n, margins_ot[grammar_k, n-3, m])
                edim_hg[grammar_k, n-3, m] = effective_dimension(weights_hg[grammar_k, n-3, m, :], n, margins_hg[grammar_k, n-3, m])
    
    if args.output is not None:
        np.savez(args.output, margins_ot=margins_ot, margins_hg=margins_hg, errors_ot=errors_ot, errors_hg=errors_hg,
            ot_mask=ot_mask, weights_ot=weights_ot, weights_hg=weights_hg, edim_ot=edim_ot, edim_hg=edim_hg, active_constraints=active_constraints)
                
                
